{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Introduction to Lingvo",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/tensorflow/lingvo/blob/master/codelabs/introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "od97_nvR82Qo"
      },
      "source": [
        "# Introduction to Lingvo\n",
        "\n",
        "This codelab will guide you through the implementation of a **sequence-to-sequence** model using [**Lingvo**](https://github.com/tensorflow/lingvo).\n",
        "\n",
        "**Sequence-to-sequence** models map input sequences of arbitrary length to\n",
        "output sequences of arbitrary length. Example uses of sequence-to-sequence\n",
        "models include machine translation, which maps a sequence of words from one\n",
        "language into a sequence of words in another language with the same meaning;\n",
        "speech recognition, which maps a sequence of acoustic features into a sequence\n",
        "of words; and text summarization, which\n",
        "maps a sequence of words into a shorter sequence which conveys the same meaning.\n",
        "\n",
        "In this codelab, you will create a model which restores punctuation and\n",
        "capitalization to text which has been lowercased and stripped of punctuation.\n",
        "For example, given the following text:\n",
        "\n",
        "> she asked do you know the way to san jose\n",
        "\n",
        "The model will output the following properly-punctuated-and-capitalized text:\n",
        "\n",
        "> She asked, \"Do know you know the way to San Jose\"?\n",
        "\n",
        "We will train an RNMT+ model based off of [\"The Best of Both Worlds: Combining Recent Advances in Neural Machine Translation. (Chen et al., 2018)\"](https://arxiv.org/abs/1804.09849)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4bPGJjZkLBdz"
      },
      "source": [
        "## Table of Contents\n",
        "\n",
        "In Colab, click `[View]->[Table of contents]` on the menu bar."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bnpyjXz1OFAu"
      },
      "source": [
        "## Prerequisites\n",
        "\n",
        "The main goal of this codelab is to teach you how to define and train sequence-to-sequence models in Lingvo. We do not aim to teach either Python or Tensorflow, and no sophisticated Python or Tensorflow programming will be required. However, the following will be helpful in understanding this codelab:\n",
        "\n",
        "-   Familiarity with high-level machine learning primitives, in particular,\n",
        "    recurrent neural networks, LSTMs, and attention.\n",
        "-   Comfort reading and writing simple Python code. In particular, you should\n",
        "    know how to define simple classes and how inheritance works.\n",
        "-   Basic knowledge of the Tensorflow training workflow.  If you have trained\n",
        "    simple Tensorflow models before (e.g., via another codelab), you should know\n",
        "    enough for this codelab.\n",
        "\n",
        "### Resources\n",
        "\n",
        "- [Introduction to RNNs and LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)\n",
        "- [Official Tensorflow Tutorials](https://www.tensorflow.org/tutorials/)\n",
        "- For more advanced topics or to get a deeper understanding of Lingvo beyond this codelab, see the [paper](https://arxiv.org/abs/1902.08295)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "coa-U7N6Kunb"
      },
      "source": [
        "## Learning Objectives\n",
        "\n",
        "This codelab will teach you the following:\n",
        "\n",
        "-   How to generate input data for training a sequence-to-sequence model in Lingvo.\n",
        "-   How models are specified and configured in Lingvo, by adapting a pre-existing model architecture for machine translation.\n",
        "-   How to use the trained model for inference.\n",
        "\n",
        "This codelab does not:\n",
        "\n",
        "-   Teach you how to design a model for solving specific tasks.\n",
        "-   Provide a state-of-the-art model for the punctuator task."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_mtJSYor9d6d"
      },
      "source": [
        "## Environment Setup\n",
        "\n",
        "To start with, we need to connect this Colab notebook with Lingvo.\n",
        "\n",
        "```shell\n",
        "mkdir -p /tmp/lingvo_codelab && cd /tmp/lingvo_codelab\n",
        "pip3 install lingvo\n",
        "python3 -m lingvo.ipython_kernel\n",
        "```\n",
        "\n",
        "Finally, on the top right hand side of this Colab notebook, open the dropdown beside \"CONNECT\" and select \"Connect to local runtime...\", enter `http://localhost:8888` and press CONNECT.\n",
        "\n",
        "You should now see the words \"CONNECTED\" and be able to execute the following cell."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2JRfZTumHxpY",
        "colab": {}
      },
      "source": [
        "import lingvo"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lkN9qajvNtEs"
      },
      "source": [
        "## Input Pipeline\n",
        "\n",
        "In order to train a sequence-to-sequence model, we need a set of pairs of source\n",
        "and target sequences. For this codelab, our source sequences will be\n",
        "text which has been lowercased and had its punctuation removed, and the target\n",
        "sequences will be the original sentences, with their original casing and\n",
        "punctuation.\n",
        "\n",
        "Since neural networks require numeric inputs, we will also need a tokenization scheme mapping the sequence of characters to a sequence of numbers. In this codelab, we will use a pre-trained word-piece model.\n",
        "\n",
        "### Download Raw Input\n",
        "\n",
        "We will use the [Brown Corpus](http://www.nltk.org/nltk_data) as the source of our training data. Run the following cell to download and preprocess the dataset. The script will generate `train.txt` and `test.txt` containing the training and test data at an 80:20 split with individual sentences on each line."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Asfie_a8YmN3",
        "colab": {}
      },
      "source": [
        "!python3 -m lingvo.tasks.punctuator.tools.download_brown_corpus --outdir=/tmp/punctuator_data\n",
        "!curl -O https://raw.githubusercontent.com/tensorflow/lingvo/master/lingvo/tasks/punctuator/params/brown_corpus_wpm.16000.vocab"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BHq87NF-8lgD"
      },
      "source": [
        "### Define an Input Generator\n",
        "\n",
        "In order to train a model, we need an input generator that provides structured mini-batches of source-target pairs. The input generator handles all the processing necessary to generate numeric data that can be fed to the model. This includes:\n",
        "\n",
        "- reading examples from the data source in random order, where the data source may be split across multiple files;\n",
        "- processing the data -- for our task this involves generating a \"source\" sentence by converting all characters to lower-case and removing punctuation, and then tokenizing both the source and target sequences into integer tokens; and\n",
        "- batching together examples by padding them to the same length. Multiple buckets of different lengths may be used to avoid inefficiency from padding a short input to a very long length.\n",
        "\n",
        "Fortunately, the majority of this is handled in the background by Lingvo. We only need to specify how the data should be processed.\n",
        "\n",
        "Input generators are subclasses of *BaseInputGenerator* found in [lingvo/core/base_input_generator.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/base_input_generator.py) and have the following structure:\n",
        "\n",
        "- a `Params` classmethod that returns a default Params object for configuring the input generator. Experiment configurations inside Lingvo are controlled using these Params objects.\n",
        "- an `InputBatch` method that returns a [`NestedMap`](https://github.com/tensorflow/lingvo/blob/3344e201719961183d88713784ccbae447f5c52a/lingvo/core/py_utils.py#L392) containing the input batch. `NestedMap` is an arbitrarily nested map structure used throughout Lingvo.\n",
        "- an optional `PreprocessInputBatch` method that preprocesses the batch returned by `InputBatch`.\n",
        "\n",
        "Here is an example of the input generator for the Punctuator task, found in [lingvo/tasks/punctuator/input_generator.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/punctuator/input_generator.py).\n",
        "\n",
        "Run the cell below to write the file to disk."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "T2BzBCqE_yvt",
        "colab": {}
      },
      "source": [
        "%%writefile input_generator.py\n",
        "import string\n",
        "import lingvo.compat as tf\n",
        "from lingvo.core import base_input_generator\n",
        "from lingvo.core import base_layer\n",
        "from lingvo.core import generic_input\n",
        "from lingvo.core import py_utils\n",
        "from lingvo.core import tokenizers\n",
        "\n",
        "\n",
        "class PunctuatorInput(base_input_generator.BaseSequenceInputGenerator):\n",
        "  \"\"\"Reads text line by line and processes them for the punctuator task.\"\"\"\n",
        "\n",
        "  @classmethod\n",
        "  def Params(cls):\n",
        "    \"\"\"Defaults params for PunctuatorInput.\"\"\"\n",
        "    p = super(PunctuatorInput, cls).Params()\n",
        "    p.tokenizer = tokenizers.WpmTokenizer.Params()\n",
        "    return p\n",
        "\n",
        "  def _ProcessLine(self, line):\n",
        "    \"\"\"A single-text-line processor.\n",
        "    Gets a string tensor representing a line of text that have been read from\n",
        "    the input file, and splits it to graphemes (characters).\n",
        "    We use original characters as the target labels, and the lowercased and\n",
        "    punctuation-removed characters as the source labels.\n",
        "    Args:\n",
        "      line: a 1D string tensor.\n",
        "    Returns:\n",
        "      A list of tensors, in the expected order by __init__.\n",
        "    \"\"\"\n",
        "    # Tokenize the input into integer ids.\n",
        "    # tgt_ids has the start-of-sentence token prepended, and tgt_labels has the\n",
        "    # end-of-sentence token appended.\n",
        "    tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(\n",
        "        tf.convert_to_tensor([line]))\n",
        "\n",
        "    def Normalize(line):\n",
        "      # Lowercase and remove punctuation.\n",
        "      line = line.lower().translate(None, string.punctuation.encode('utf-8'))\n",
        "      # Convert multiple consecutive spaces to a single one.\n",
        "      line = b' '.join(line.split())\n",
        "      return line\n",
        "\n",
        "    normalized_line = tf.py_func(Normalize, [line], tf.string, stateful=False)\n",
        "    _, src_labels, src_paddings = self.StringsToIds(\n",
        "        tf.convert_to_tensor([normalized_line]), is_source=True)\n",
        "    # The model expects the source without a start-of-sentence token.\n",
        "    src_ids = src_labels\n",
        "\n",
        "    # Compute the length for bucketing.\n",
        "    bucket_key = tf.cast(\n",
        "        tf.round(\n",
        "            tf.maximum(\n",
        "                tf.reduce_sum(1.0 - src_paddings),\n",
        "                tf.reduce_sum(1.0 - tgt_paddings))), tf.int32)\n",
        "    tgt_weights = 1.0 - tgt_paddings\n",
        "\n",
        "    # Return tensors in an order consistent with __init__.\n",
        "    out_tensors = [\n",
        "        src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights\n",
        "    ]\n",
        "    return [tf.squeeze(t, axis=0) for t in out_tensors], bucket_key\n",
        "\n",
        "  def _DataSourceFromFilePattern(self, file_pattern):\n",
        "    \"\"\"Create the input processing op.\n",
        "    Args:\n",
        "      file_pattern: The file pattern to use as input.\n",
        "    Returns:\n",
        "      an operation that when executed, calls `_ProcessLine` on a line read\n",
        "    from `file_pattern`.\n",
        "    \"\"\"\n",
        "    return generic_input.GenericInput(\n",
        "        file_pattern=file_pattern,\n",
        "        processor=self._ProcessLine,\n",
        "        # Pad dimension 0 to the same length.\n",
        "        dynamic_padding_dimensions=[0] * 6,\n",
        "        # The constant values to use for padding each of the outputs.\n",
        "        dynamic_padding_constants=[0, 1, 0, 1, 0, 0],\n",
        "        **self.CommonInputOpArgs())\n",
        "\n",
        "  @base_layer.initializer\n",
        "  def __init__(self, params):\n",
        "    super(PunctuatorInput, self).__init__(params)\n",
        "\n",
        "    # Build the input processing graph.\n",
        "    (self._src_ids, self._src_paddings, self._tgt_ids, self._tgt_paddings,\n",
        "     self._tgt_labels,\n",
        "     self._tgt_weights), self._bucket_keys = self._BuildDataSource()\n",
        "\n",
        "    self._input_batch_size = tf.shape(self._src_ids)[0]\n",
        "    self._sample_ids = tf.range(0, self._input_batch_size, 1)\n",
        "\n",
        "  def InputBatch(self):\n",
        "    \"\"\"Returns a single batch as a `.NestedMap` to be passed to the model.\"\"\"\n",
        "    ret = py_utils.NestedMap()\n",
        "\n",
        "    ret.bucket_keys = self._bucket_keys\n",
        "\n",
        "    ret.src = py_utils.NestedMap()\n",
        "    ret.src.ids = tf.cast(self._src_ids, dtype=tf.int32)\n",
        "    ret.src.paddings = self._src_paddings\n",
        "\n",
        "    ret.tgt = py_utils.NestedMap()\n",
        "    ret.tgt.ids = self._tgt_ids\n",
        "    ret.tgt.labels = tf.cast(self._tgt_labels, dtype=tf.int32)\n",
        "    ret.tgt.weights = self._tgt_weights\n",
        "    ret.tgt.paddings = self._tgt_paddings\n",
        "\n",
        "    return ret"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lrwT1f0IfPca"
      },
      "source": [
        "## Model Definition\n",
        "\n",
        "Next, we need to define the network structure for the task. The network is a nested structure of layers. Most classes in Lingvo are subclasses of *BaseLayer* found in [lingvo/core/base_layer.py](https://github.com/tensorflow/lingvo/blob/560f838bd576c7b911df379121eb58252b6ae326/lingvo/core/base_layer.py#L150) and inherit the following:\n",
        "\n",
        "- a Params classmethod that returns a default [Params](https://github.com/tensorflow/lingvo/blob/3344e201719961183d88713784ccbae447f5c52a/lingvo/core/hyperparams.py#L151) object for configuring the class. In addition to hyperparameters, the Params object can also contain Params objects for configuring child layers. Some of the properties present in all Params objects include:\n",
        "  - `cls`: the python class that the Params object is associated with. This can be used to construct an instance of the class;\n",
        "  - `name`: the name of this layer;\n",
        "  - `dtype`: the default dtype to use when creating variables.\n",
        "- The `__init__` constructor. All child layers and variables should be created here.\n",
        "- A `CreateVariable` method that is called to create variables.\n",
        "- A `CreateChild` method that is called to create child layers.\n",
        "- A `FProp` method that implements forward propagation through the layer.\n",
        "\n",
        "As a reference, many examples of layers can be found in [lingvo/core/layers.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/layers.py), [lingvo/core/attention.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/attention.py), and [lingvo/core/rnn_layers.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_layers.py).\n",
        "\n",
        "&nbsp;\n",
        "\n",
        "The root layer for the network should be a subclass of `BaseTask` found in [lingvo/core/base_model.py](https://github.com/tensorflow/lingvo/blob/918c584f057481717eff6e1e29ae028aeab3d165/lingvo/core/base_model.py#L79), and implements the following:\n",
        "\n",
        "- A `ComputePredictions` method that takes the current variable values (`theta`) and `input_batch` and returns the network predictions.\n",
        "- A `ComputeLoss` method that takes `theta`, `input_batch`, and the `predictions` returned from `ComputePredictions` and returns a dictionary of scalar metrics, one of which should be `loss`. These scalar metrics are exported to TensorBoard as summaries.\n",
        "- An optional `Decode` method for creating a separate graph for decoding. For example, training and evaluation might use teacher forcing while decoding might not.\n",
        "- An optional `Inference` method that returns a graph with feeds and fetches that can be used together with a saved checkpoint for inference. This differs from `Decode` in that it can be fed data directly instead of using data from the input generator.\n",
        "\n",
        "&nbsp;\n",
        "\n",
        "This codelab uses the existing networks from [lingvo/tasks/punctuator/model.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/punctuator/model.py), which is derived from the networks in [lingvo/tasks/mt/model.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/mt/model.py) with an added `Inference` method for the punctuator task. The actual logic lies mostly in [lingvo/tasks/mt/encoder.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/mt/encoder.py) and [lingvo/tasks/mt/decoder.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/mt/decoder.py)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9ItaOgXNpql0"
      },
      "source": [
        "## Model Configuration\n",
        "\n",
        "After defining the input generator and the network, we need to create an model configuration with the specific hyperparameters to use for our model.\n",
        "\n",
        "Since there is only a single task, we create a subclass of `SingleTaskModelParams` found in [lingvo/core/base_model_params.py](https://github.com/tensorflow/lingvo/blob/4747cf80a7e6cf58211aa899bae854820a3b42f6/lingvo/core/base_model_params.py#L47). It has the following structure:\n",
        "\n",
        "- The `Train`/`Dev`/`Test` methods configure the input generator for the respective datasets.\n",
        "- The `Task` method configures the network.\n",
        "\n",
        "The following cell contains the configuration that will be used in this codelab. It can also be found in [lingvo/tasks/punctuator/params/codelab.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/punctuator/params/codelab.py). The network configuration in the `Task` classmethod is delegated to [lingvo/tasks/mt/params/base_config.py](https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/mt/params/base_config.py).\n",
        "\n",
        "Run the cell below to write the file to disk."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3CjruUvXY5we",
        "colab": {}
      },
      "source": [
        "%%writefile codelab.py\n",
        "import input_generator\n",
        "import os\n",
        "from lingvo import model_registry\n",
        "import lingvo.compat as tf\n",
        "from lingvo.core import base_model_params\n",
        "from lingvo.tasks.mt import base_config\n",
        "from lingvo.tasks.punctuator import model\n",
        "\n",
        "\n",
        "# This base class defines parameters for the input generator for a specific\n",
        "# dataset. Specific network architectures will be implemented in subclasses.\n",
        "class BrownCorpusWPM(base_model_params.SingleTaskModelParams):\n",
        "  \"\"\"Brown Corpus data with a Word-Piece Model tokenizer.\"\"\"\n",
        "\n",
        "  # Generated using\n",
        "  # lingvo/tasks/punctuator/tools:download_brown_corpus.\n",
        "  _DATADIR = '/tmp/punctuator_data'\n",
        "  _VOCAB_FILE = 'brown_corpus_wpm.16000.vocab'\n",
        "  # _VOCAB_SIZE needs to be a multiple of 16 because we use a sharded softmax\n",
        "  # with 16 shards.\n",
        "  _VOCAB_SIZE = 16000\n",
        "\n",
        "  def Train(self):\n",
        "    p = input_generator.PunctuatorInput.Params()\n",
        "    p.file_pattern = 'text:' + os.path.join(self._DATADIR, 'train.txt')\n",
        "    p.file_random_seed = 0  # Do not use a fixed seed.\n",
        "    p.file_parallelism = 1  # We only have a single input file.\n",
        "\n",
        "    # The bucket upper bound specifies how to split the input into buckets. We\n",
        "    # train on sequences up to maximum bucket size and discard longer examples.\n",
        "    p.bucket_upper_bound = [10, 20, 30, 60, 120]\n",
        "\n",
        "    # The bucket batch limit determines how many examples are there in each\n",
        "    # batch during training. We reduce the batch size for the buckets that\n",
        "    # have higher upper bound (batches that consist of longer sequences)\n",
        "    # in order to prevent out of memory issues.\n",
        "    # Note that this hyperparameter varies widely based on the model and\n",
        "    # language. Larger models may warrant smaller batches in order to fit in\n",
        "    # memory, for example; and ideographical languages like Chinese may benefit\n",
        "    # from more buckets.\n",
        "    p.bucket_batch_limit = [512, 256, 160, 80, 40]\n",
        "\n",
        "    p.tokenizer.vocab_filepath = self._VOCAB_FILE\n",
        "    p.tokenizer.vocab_size = self._VOCAB_SIZE\n",
        "    p.tokenizer.pad_to_max_length = False\n",
        "\n",
        "    # Set the tokenizer max length slightly longer than the largest bucket to\n",
        "    # discard examples that are longer than we allow.\n",
        "    p.source_max_length = p.bucket_upper_bound[-1] + 2\n",
        "    p.target_max_length = p.bucket_upper_bound[-1] + 2\n",
        "    return p\n",
        "\n",
        "  # There is also a Dev method for dev set params, but we don't have a dev set.\n",
        "  def Test(self):\n",
        "    p = input_generator.PunctuatorInput.Params()\n",
        "    p.file_pattern = 'text:' + os.path.join(self._DATADIR, 'test.txt')\n",
        "    p.file_random_seed = 27182818  # Fix random seed for testing.\n",
        "    # The following two parameters are important if there's more than one input\n",
        "    # file. For this codelab it doesn't actually matter.\n",
        "    p.file_parallelism = 1  # Avoid randomness in testing.\n",
        "    # In order to make exactly one pass over the dev/test sets, we set buffer\n",
        "    # size to 1. Greater numbers may cause inaccurate dev/test scores.\n",
        "    p.file_buffer_size = 1\n",
        "\n",
        "    p.bucket_upper_bound = [10, 20, 30, 60, 120, 200]\n",
        "    p.bucket_batch_limit = [16] * 4 + [4] * 2\n",
        "\n",
        "    p.tokenizer.vocab_filepath = self._VOCAB_FILE\n",
        "    p.tokenizer.vocab_size = self._VOCAB_SIZE\n",
        "    p.tokenizer.pad_to_max_length = False\n",
        "\n",
        "    p.source_max_length = p.bucket_upper_bound[-1] + 2\n",
        "    p.target_max_length = p.bucket_upper_bound[-1] + 2\n",
        "    return p\n",
        "\n",
        "\n",
        "# This decorator registers the model in the Lingvo model registry.\n",
        "# This file is lingvo/tasks/punctuator/params/codelab.py,\n",
        "# so the model will be registered as punctuator.codelab.RNMTModel.\n",
        "@model_registry.RegisterSingleTaskModel\n",
        "class RNMTModel(BrownCorpusWPM):\n",
        "  \"\"\"RNMT+ Model.\"\"\"\n",
        "\n",
        "  def Task(self):\n",
        "    p = base_config.SetupRNMTParams(\n",
        "        model.RNMTModel.Params(),\n",
        "        name='punctuator_rnmt',\n",
        "        vocab_size=self._VOCAB_SIZE,\n",
        "        embedding_dim=1024,\n",
        "        hidden_dim=1024,\n",
        "        num_heads=4,\n",
        "        num_encoder_layers=6,\n",
        "        num_decoder_layers=8,\n",
        "        learning_rate=1e-4,\n",
        "        l2_regularizer_weight=1e-5,\n",
        "        lr_warmup_steps=500,\n",
        "        lr_decay_start=400000,\n",
        "        lr_decay_end=1200000,\n",
        "        lr_min=0.5,\n",
        "        ls_uncertainty=0.1,\n",
        "        atten_dropout_prob=0.3,\n",
        "        residual_dropout_prob=0.3,\n",
        "        adam_beta2=0.98,\n",
        "        adam_epsilon=1e-6,\n",
        "    )\n",
        "    p.eval.samples_per_summary = 2466\n",
        "    return p"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ooJYcOA4tcXW"
      },
      "source": [
        "## Model Training\n",
        "\n",
        "The following cell trains the model. Note that this will require approximately 2.5GB of space in `logdir`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "91eOF9Sqy1iG",
        "colab": {}
      },
      "source": [
        "# Start tensorboard (access at http://localhost:6006)\n",
        "import os\n",
        "os.system('lsof -t -i:6006 || tensorboard --logdir=/tmp/punctuator &')\n",
        "\n",
        "!python3 -m lingvo.trainer --model=codelab.RNMTModel --mode=sync --logdir=/tmp/punctuator --saver_max_to_keep=2 --noenable_asserts --run_locally=gpu"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sOQMKZiuwNKr"
      },
      "source": [
        "The following cell evaluates the model. In Lingvo, evaluation is meant to be run alongside training as a separate process that periodically looks for the latest checkpoint and evaluates it. There is only one process in Colab so running this cell will evaluate the current checkpoint then it will block indefinitely waiting for the next checkpoint."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "leWsSPQ6L__H",
        "colab": {}
      },
      "source": [
        "!python3 -m lingvo.trainer --model=codelab.RNMTModel --job=evaler_test --logdir=/tmp/punctuator --run_locally=cpu"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wwqXVIaG2hA_"
      },
      "source": [
        "There is also a Decoder job that can be run the same way. The difference between the Evaler and Decoder varies by model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "g1V7f-Oxw9vp"
      },
      "source": [
        "## Model Inference\n",
        "\n",
        "After the model has been trained for around 10-20k steps (a few hours on GPU), its inference graph can be used to interact with the model using arbitrary inputs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "both",
        "colab_type": "code",
        "id": "7KU0zeUTKS9Z",
        "colab": {}
      },
      "source": [
        "import string\n",
        "\n",
        "from lingvo import compat as tf\n",
        "from lingvo import model_imports\n",
        "from lingvo import model_registry\n",
        "from lingvo.core import inference_graph_exporter\n",
        "from lingvo.core import predictor\n",
        "from lingvo.core.ops.hyps_pb2 import Hypothesis\n",
        "\n",
        "tf.flags.FLAGS.mark_as_parsed()\n",
        "\n",
        "\n",
        "src = \"she asked do you know the way to san jose\" #@param {type:'string'}\n",
        "src = src.lower().translate(str.maketrans('', '', string.punctuation))\n",
        "print(src)\n",
        "\n",
        "checkpoint = tf.train.latest_checkpoint('/tmp/punctuator/train')\n",
        "print('Using checkpoint %s' % checkpoint)\n",
        "\n",
        "# Run inference\n",
        "params = model_registry.GetParams('codelab.RNMTModel', 'Test')\n",
        "inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(params)\n",
        "pred = predictor.Predictor(inference_graph, \n",
        "                           checkpoint=checkpoint, \n",
        "                           device_type='cpu')\n",
        "src_ids, decoded, scores, hyps = pred.Run(\n",
        "    ['src_ids', 'topk_decoded', 'topk_scores', 'topk_hyps'], src_strings=[src])\n",
        "# print(src_ids[0])\n",
        "for text, score in zip(decoded[0].tolist(), scores[0].tolist()):\n",
        "  print(\"%.5f: %s\" % (score, text))\n",
        "# for i, hyp in enumerate(hyps[0]):\n",
        "#   print(\"=======hyp %d=======\" % i)\n",
        "#   print(Hypothesis().FromString(hyp))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_hW7aTI8fM4O"
      },
      "source": [
        "## Additional Resources\n",
        "\n",
        "For more advanced topics or to get a deeper understanding of Lingvo beyond this codelab, see the [paper](https://arxiv.org/abs/1902.08295)."
      ]
    }
  ]
}
